"""
Patches views.

| Copyright 2017-2025, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
from collections import defaultdict
from copy import deepcopy

from bson import ObjectId

import eta.core.utils as etau

import fiftyone.core.aggregations as foa
import fiftyone.core.dataset as fod
import fiftyone.core.fields as fof
import fiftyone.core.labels as fol
import fiftyone.core.media as fom
import fiftyone.core.odm as foo
import fiftyone.core.sample as fos
import fiftyone.core.validation as fova
import fiftyone.core.view as fov


# Label list types from which patches can be extracted; any other label type
# is rejected with a ValueError when building a patches view
_PATCHES_TYPES = (fol.Detections, fol.Polylines, fol.Keypoints)

# Value of a label's ``<eval_key>_id`` field that, like a missing/null value,
# is treated as "unmatched" when selecting unmatched evaluation labels
_NO_MATCH_ID = ""


class _PatchView(fos.SampleView):
    """Base class for the sample views yielded by patches views.

    The backing document records the ID of the source sample (and, for frame
    patches, the source frame) that the patch came from, and saving a patch
    pushes its label edits back to the source collection via the owning
    view.
    """

    @property
    def _sample_id(self):
        # ID of the source sample this patch was extracted from
        raw_id = self._doc.sample_id
        return ObjectId(raw_id)

    @property
    def _frame_id(self):
        # ID of the source frame this patch was extracted from
        raw_id = self._doc.frame_id
        return ObjectId(raw_id)

    def _save(self, deferred=False):
        sample_ops, frame_ops = super()._save(deferred=deferred)

        # Deferred saves are flushed later, so only sync immediate saves
        if deferred:
            return sample_ops, frame_ops

        self._view._sync_source_sample(self)
        return sample_ops, frame_ops


class PatchView(_PatchView):
    """A single patch belonging to a :class:`PatchesView`.

    Do not instantiate :class:`PatchView` objects directly; they are produced
    when iterating over a :class:`PatchesView`.

    Args:
        doc: a :class:`fiftyone.core.odm.DatasetSampleDocument`
        view: the :class:`PatchesView` that owns this patch
        selected_fields (None): a set of field names to which this view is
            restricted
        excluded_fields (None): a set of field names that this view excludes
        filtered_fields (None): a set of names of list fields that this view
            filters
    """


class EvaluationPatchView(_PatchView):
    """A single patch belonging to an :class:`EvaluationPatchesView`.

    Do not instantiate :class:`EvaluationPatchView` objects directly; they are
    produced when iterating over an :class:`EvaluationPatchesView`.

    Args:
        doc: a :class:`fiftyone.core.odm.DatasetSampleDocument`
        view: the :class:`EvaluationPatchesView` that owns this patch
        selected_fields (None): a set of field names to which this view is
            restricted
        excluded_fields (None): a set of field names that this view excludes
        filtered_fields (None): a set of names of list fields that this view
            filters
    """


class _PatchesView(fov.DatasetView):
    """Base class for views whose samples are patches extracted from the label
    fields of a source collection.

    The patches are served by a separate ``patches_dataset``. Edits to the
    fields listed in ``_label_fields`` are synced back to the corresponding
    labels of ``source_collection``.

    Subclasses must implement the ``_label_fields`` property.

    Args:
        source_collection: the
            :class:`fiftyone.core.collections.SampleCollection` from which the
            patches were extracted
        patches_stage: the view stage that defines how the patches were
            extracted
        patches_dataset: the :class:`fiftyone.core.dataset.Dataset` that
            serves the patches
        _stages (None): additional view stages applied on top of the patches
        _media_type (None): an optional media type that overrides the
            patches dataset's media type
        _name (None): the name of this view, if it has been saved
    """

    __slots__ = (
        "_source_collection",
        "_patches_stage",
        "_patches_dataset",
        "__stages",
        "__media_type",
        "__name",
    )

    def __init__(
        self,
        source_collection,
        patches_stage,
        patches_dataset,
        _stages=None,
        _media_type=None,
        _name=None,
    ):
        if _stages is None:
            _stages = []

        self._source_collection = source_collection
        self._patches_stage = patches_stage
        self._patches_dataset = patches_dataset
        self.__stages = _stages
        self.__media_type = _media_type
        self.__name = _name

    def __copy__(self):
        return self.__class__(
            self._source_collection,
            deepcopy(self._patches_stage),
            self._patches_dataset,
            _stages=deepcopy(self.__stages),
            _media_type=self.__media_type,
            _name=self.__name,
        )

    @property
    def _base_view(self):
        # The patches view without any of the additional stages
        return self.__class__(
            self._source_collection,
            self._patches_stage,
            self._patches_dataset,
        )

    @property
    def _dataset(self):
        return self._patches_dataset

    @property
    def _root_dataset(self):
        return self._source_collection._root_dataset

    @property
    def _stages(self):
        return self.__stages

    @property
    def _all_stages(self):
        # Source stages, then the patches stage, then stages on the patches
        return (
            self._source_collection.view()._all_stages
            + [self._patches_stage]
            + self.__stages
        )

    @property
    def _id_field(self):
        # The patch field that stores the ID of the source collection's
        # sample: frame patches point at frames, others at samples
        if self._is_frames:
            return "frame_id"

        return "sample_id"

    @property
    def _label_fields(self):
        raise NotImplementedError("subclass must implement _label_fields")

    @property
    def name(self):
        return self.__name

    @property
    def is_saved(self):
        return self.__name is not None

    @property
    def media_type(self):
        if self.__media_type is not None:
            return self.__media_type

        return self._dataset.media_type

    def _set_name(self, name):
        self.__name = name

    def _set_media_type(self, media_type):
        self.__media_type = media_type

    def _tag_labels(self, tags, label_field, ids=None, label_ids=None):
        must_sync = label_field in self._label_fields

        # Record the source IDs before tagging; presumably because tagging
        # may change which patches this view contains
        _ids = self.values("_" + self._id_field) if must_sync else None

        _, label_ids = super()._tag_labels(
            tags, label_field, ids=ids, label_ids=label_ids
        )

        if must_sync:
            ids, label_ids = self._to_source_ids(label_field, _ids, label_ids)
            self._source_collection._tag_labels(
                tags, label_field, ids=ids, label_ids=label_ids
            )

    def _untag_labels(self, tags, label_field, ids=None, label_ids=None):
        must_sync = label_field in self._label_fields

        # Record the source IDs before untagging; presumably because
        # untagging may change which patches this view contains
        _ids = self.values("_" + self._id_field) if must_sync else None

        _, label_ids = super()._untag_labels(
            tags, label_field, ids=ids, label_ids=label_ids
        )

        if must_sync:
            ids, label_ids = self._to_source_ids(label_field, _ids, label_ids)
            self._source_collection._untag_labels(
                tags, label_field, ids=ids, label_ids=label_ids
            )

    def _to_source_ids(self, label_field, ids, label_ids):
        """Converts per-patch ID lists into the form the source expects.

        If ``label_field`` is a list field in the source collection, the
        label IDs of all patches from the same source sample/frame are grouped
        under that sample/frame's ID. Otherwise the inputs are returned as-is.

        Args:
            label_field: the label field
            ids: the source sample/frame ID of each patch
            label_ids: the label ID (or list of label IDs) of each patch

        Returns:
            an ``(ids, label_ids)`` tuple
        """
        _, is_list_field = self._source_collection._get_label_field_root(
            label_field
        )

        if not is_list_field:
            return ids, label_ids

        id_map = defaultdict(list)
        for _id, _label_id in zip(ids, label_ids):
            if etau.is_container(_label_id):
                id_map[_id].extend(_label_id)
            else:
                id_map[_id].append(_label_id)

        # Always return lists (empty when there is nothing to map)
        return list(id_map.keys()), list(id_map.values())

    def set_values(self, field_name, *args, **kwargs):
        # Values written to a patch label field must also be synced to the
        # source collection, both data and schema
        field = field_name.split(".", 1)[0]
        must_sync = field in self._label_fields

        # The `set_values()` operation could change the contents of this view,
        # so we first record the sample IDs that need to be synced
        if must_sync and self._stages:
            ids = self.values("id")
        else:
            ids = None

        super().set_values(field_name, *args, **kwargs)

        if must_sync:
            self._sync_source_field(field, ids=ids)
            self._sync_source_field_schema(field_name)

    def set_label_values(self, field_name, *args, **kwargs):
        field = field_name.split(".", 1)[0]
        must_sync = field in self._label_fields

        super().set_label_values(field_name, *args, **kwargs)

        if must_sync:
            # Translate the path from this view's label root (which may be a
            # single label) to the source's label root (which may be a list)
            _, root = self._get_label_field_path(field)
            _, src_root = self._source_collection._get_label_field_path(field)
            _field_name = src_root + field_name[len(root) :]

            self._source_collection.set_label_values(
                _field_name, *args, **kwargs
            )

    def save(self, fields=None):
        """Saves the patches in this view to the underlying dataset.

        If this view contains any additional fields that were not extracted
        from the underlying dataset, these fields are not saved.

        This method **does not** delete patches from the underlying dataset
        that this view excludes.

        .. note::

            This method is not a :class:`fiftyone.core.stages.ViewStage`;
            it immediately writes the requested changes to the underlying
            dataset.

        Args:
            fields (None): an optional field or list of fields to save. If
                specified, only these fields are overwritten
        """
        if etau.is_str(fields):
            fields = [fields]

        super().save(fields=fields)

        if fields is None:
            fields = self._label_fields
        else:
            fields = [f for f in fields if f in self._label_fields]

        self._sync_source(fields=fields)

    def keep(self):
        """Deletes all patches that are **not** in this view from the
        underlying dataset.

        .. note::

            This method is not a :class:`fiftyone.core.stages.ViewStage`;
            it immediately writes the requested changes to the underlying
            dataset.
        """

        # The `keep()` operation below will delete patches, so we must sync
        # deletions to the source dataset first
        self._sync_source(update=False, delete=True)

        super().keep()

    def keep_fields(self):
        """Deletes all patch field(s) that have been excluded from this view
        from the underlying dataset.

        .. note::

            This method is not a :class:`fiftyone.core.stages.ViewStage`;
            it immediately writes the requested changes to the underlying
            dataset.
        """
        self._sync_source_keep_fields()

        super().keep_fields()

    def reload(self):
        """Reloads the view.

        Note that :class:`PatchView` instances are not singletons, so any
        in-memory patches extracted from this view will not be updated by
        calling this method.
        """
        self._source_collection.reload()

        # Regenerate the patches dataset
        _view = self._patches_stage.load_view(
            self._source_collection, reload=True
        )
        self._patches_dataset = _view._patches_dataset

        super().reload()

    def _delete_labels(self, labels, fields=None):
        patch_labels, other_labels, src_labels = self._parse_labels(
            labels, fields=fields
        )

        # Each patch holds exactly one label, so deleting it means deleting
        # the whole patch sample
        if patch_labels:
            patch_ids = [d["sample_id"] for d in patch_labels]
            self._patches_dataset.delete_samples(patch_ids)

        if other_labels:
            super()._delete_labels(other_labels, fields=fields)

        if src_labels:
            self._source_collection._delete_labels(src_labels, fields=fields)

    def _parse_labels(self, labels, fields=None):
        """Splits label dicts into patch-level, other, and source-level lists.

        Returns:
            a ``(patch_labels, other_labels, src_labels)`` tuple, where
            ``patch_labels`` is ``None`` when this view has multiple label
            fields (those labels are then included in ``other_labels``), and
            ``src_labels`` are copies of the patch labels whose
            ``"sample_id"`` refers to the source collection
        """
        if etau.is_str(fields):
            fields = [fields]

        if fields is not None:
            labels = [d for d in labels if d["field"] in fields]

        label_fields = self._label_fields

        patch_labels = [d for d in labels if d["field"] in label_fields]
        other_labels = [d for d in labels if d["field"] not in label_fields]

        src_labels = deepcopy(patch_labels)
        if src_labels:
            # Map patch IDs to the IDs of the source collection's samples.
            # For frame patches the source is a frames view whose samples are
            # identified by frame ID, consistent with the other sync paths
            patch_ids = [d["sample_id"] for d in src_labels]
            sample_ids = self._map_values(patch_ids, "id", self._id_field)
            for d, sample_id in zip(src_labels, sample_ids):
                d["sample_id"] = sample_id

        if len(label_fields) != 1:
            other_labels += patch_labels
            patch_labels = None

        return patch_labels, other_labels, src_labels

    def _sync_source_sample(self, sample):
        for field in self._label_fields:
            self._sync_source_sample_field(sample, field)

    def _sync_source_sample_field(self, sample, field):
        label_type = self._patches_dataset._get_label_field_type(field)
        is_list_field = issubclass(label_type, fol._HasLabelList)

        sample_id = sample[self._id_field]

        doc = sample._doc.field_to_mongo(field)
        if is_list_field:
            # Only the list of labels is written to the source
            doc = doc[label_type._LABEL_LIST_FIELD]

        self._source_collection._set_labels(field, [sample_id], [doc])

    def _sync_source(self, fields=None, ids=None, update=True, delete=False):
        if fields is not None:
            fields = [f for f in fields if f in self._label_fields]
            if not fields:
                return
        else:
            fields = self._label_fields

        for field in fields:
            self._sync_source_field(
                field, ids=ids, update=update, delete=delete
            )

    def _sync_source_field(self, field, ids=None, update=True, delete=False):
        """Syncs a patch label field to the source collection.

        Args:
            field: the label field
            ids (None): an optional list of patch IDs to restrict updates to
            update (True): whether to write the patches' labels to the source
            delete (False): whether to delete from the source any labels in
                the patches dataset that are not in this view
        """
        if field not in self._label_fields:
            return

        _, label_path = self._get_label_field_path(field)

        if ids is not None:
            view = self._patches_dataset.select(ids)
        else:
            view = self._patches_dataset

        if update:
            sample_ids, docs = view.aggregate(
                [foa.Values(self._id_field), foa.Values(label_path, _raw=True)]
            )

            self._source_collection._set_labels(field, sample_ids, docs)

        if delete:
            label_id_path = label_path + ".id"
            self_ids = set(self.values(label_id_path, unwind=True))
            all_sample_ids, all_label_ids = self._patches_dataset.values(
                [self._id_field, label_id_path]
            )

            del_labels = []
            for sample_id, label_ids in zip(all_sample_ids, all_label_ids):
                if label_ids is None:
                    continue

                if not etau.is_container(label_ids):
                    label_ids = [label_ids]

                for label_id in label_ids:
                    if label_id not in self_ids:
                        del_labels.append(
                            {
                                "label_id": label_id,
                                "sample_id": sample_id,
                                "field": field,
                            }
                        )

            if del_labels:
                self._source_collection._delete_labels(
                    del_labels, fields=field
                )

    def _sync_source_field_schema(self, path):
        """Merges the schema of a label attribute ``path`` into the source."""
        root = path.split(".", 1)[0]
        if root not in self._label_fields:
            return

        field = self.get_field(path)
        if field is None:
            return

        _, label_root = self._get_label_field_path(root)
        leaf = path[len(label_root) + 1 :]

        # `path` is the label root itself (or not below it), so there is no
        # attribute to merge; proceeding would build an invalid path with a
        # trailing "."
        if not leaf:
            return

        dst_dataset = self._source_collection._dataset
        _, dst_path = dst_dataset._get_label_field_path(root)
        dst_path += "." + leaf

        dst_dataset._merge_sample_field_schema({dst_path: field})

        if self._source_collection._is_generated:
            self._source_collection._sync_source_field_schema(dst_path)

    def _sync_source_keep_fields(self):
        # Label fields excluded from this view are deleted from the source
        src_schema = self.get_field_schema()

        del_fields = set(self._label_fields) - set(src_schema.keys())
        if del_fields:
            self._source_collection.exclude_fields(del_fields).keep_fields()


class PatchesView(_PatchesView):
    """A :class:`fiftyone.core.view.DatasetView` of patches extracted from a
    :class:`fiftyone.core.dataset.Dataset`.

    A patches view is an ordered collection of patch samples. Each patch holds
    the part of a parent sample that corresponds to one object, or to one
    logical grouping of objects.

    Iterating over a patches view yields :class:`PatchView` objects.

    Args:
        source_collection: the
            :class:`fiftyone.core.collections.SampleCollection` that this view
            was created from
        patches_stage: the :class:`fiftyone.core.stages.ToPatches` stage
            describing how the patches were extracted
        patches_dataset: the :class:`fiftyone.core.dataset.Dataset` that backs
            the patches in this view
    """

    __slots__ = ("_patches_field",)

    def __init__(
        self,
        source_collection,
        patches_stage,
        patches_dataset,
        _stages=None,
        _media_type=None,
        _name=None,
    ):
        base_kwargs = dict(
            _stages=_stages,
            _media_type=_media_type,
            _name=_name,
        )
        super().__init__(
            source_collection, patches_stage, patches_dataset, **base_kwargs
        )

        # The single label field from which the patches were extracted
        self._patches_field = patches_stage.field

    @property
    def _sample_cls(self):
        return PatchView

    @property
    def _label_fields(self):
        return [self.patches_field]

    @property
    def patches_field(self):
        """The field from which the patches in this view were extracted."""
        return self._patches_field


class EvaluationPatchesView(_PatchesView):
    """A :class:`fiftyone.core.view.DatasetView` of evaluation patches from a
    :class:`fiftyone.core.dataset.Dataset`.

    An evaluation patches view is an ordered collection of evaluation
    examples. Each example holds the ground truth and/or predicted labels for
    one true positive, false positive, or false negative from an evaluation
    run on the underlying dataset.

    Iterating over an evaluation patches view yields
    :class:`EvaluationPatchView` objects.

    Args:
        source_collection: the
            :class:`fiftyone.core.collections.SampleCollection` that this view
            was created from
        patches_stage: the :class:`fiftyone.core.stages.ToEvaluationPatches`
            stage describing how the patches were extracted
        patches_dataset: the :class:`fiftyone.core.dataset.Dataset` that backs
            the patches in this view
    """

    __slots__ = ("_gt_field", "_pred_field")

    def __init__(
        self,
        source_collection,
        patches_stage,
        patches_dataset,
        _stages=None,
        _media_type=None,
        _name=None,
    ):
        base_kwargs = dict(
            _stages=_stages,
            _media_type=_media_type,
            _name=_name,
        )
        super().__init__(
            source_collection, patches_stage, patches_dataset, **base_kwargs
        )

        # The ground truth/predictions fields come from the config of the
        # evaluation run referenced by the patches stage
        config = source_collection.get_evaluation_info(
            patches_stage.eval_key
        ).config
        self._gt_field = config.gt_field
        self._pred_field = config.pred_field

    @property
    def _sample_cls(self):
        return EvaluationPatchView

    @property
    def _label_fields(self):
        return [self.gt_field, self.pred_field]

    @property
    def gt_field(self):
        """The ground truth field for the evaluation patches in this view."""
        return self._gt_field

    @property
    def pred_field(self):
        """The predictions field for the evaluation patches in this view."""
        return self._pred_field


def make_patches_dataset(
    sample_collection,
    field,
    other_fields=None,
    keep_label_lists=False,
    include_indexes=False,
    name=None,
    persistent=False,
    _generated=False,
):
    """Creates a dataset that contains one sample per object patch in the
    specified field of the collection.

    A ``sample_id`` field will be added that records the sample ID from which
    each patch was taken.

    By default, fields other than ``field`` and the default sample fields will
    not be included in the returned dataset.

    Args:
        sample_collection: a
            :class:`fiftyone.core.collections.SampleCollection`
        field: the patches field, which must be of type
            :class:`fiftyone.core.labels.Detections`,
            :class:`fiftyone.core.labels.Polylines`, or
            :class:`fiftyone.core.labels.Keypoints`
        other_fields (None): controls whether fields other than ``field`` and
            the default sample fields are included. Can be any of the
            following:

            -   a field or list of fields to include
            -   ``True`` to include all other fields
            -   ``None``/``False`` to include no other fields
        keep_label_lists (False): whether to store the patches in label list
            fields of the same type as the input collection rather than using
            their single label variants
        include_indexes (False): whether to recreate any custom indexes on
            ``field`` and ``other_fields`` on the new dataset (True) or a list
            of specific indexes or index prefixes to recreate. By default, no
            custom indexes are recreated
        name (None): a name for the dataset
        persistent (False): whether the dataset should persist in the database
            after the session terminates
        _generated (False): internal flag forwarded to the dataset constructor
            as ``_patches`` (and as ``_frames`` for frame patches)

    Returns:
        a :class:`fiftyone.core.dataset.Dataset`

    Raises:
        ValueError: if ``field`` is a frame-level field, is nested more than
            one level deep, or is not a supported label type
    """
    if sample_collection._is_frame_field(field):
        raise ValueError(
            "Frame label patches cannot be directly extracted; you must first "
            "convert your video dataset to frames via `to_frames()`"
        )

    fova.validate_collection(sample_collection, media_type=fom.IMAGE)

    if etau.is_str(other_fields):
        other_fields = [other_fields]
    elif not other_fields:
        # `False` is documented as equivalent to `None`; normalize it so that
        # downstream helpers never try to iterate over a boolean
        other_fields = None

    # Validate the patches field *before* creating the dataset so that invalid
    # inputs don't leave behind an orphaned (possibly persistent) dataset
    keys = field.split(".")
    if len(keys) > 2:
        raise ValueError(
            "Cannot create patches from nested field '%s' of depth %d > 2"
            % (field, len(keys))
        )

    label_type = sample_collection._get_label_field_type(field)
    if label_type is None or not issubclass(label_type, _PATCHES_TYPES):
        raise ValueError(
            "Invalid label field type %s. Extracting patches is only "
            "supported for the following types: %s"
            % (label_type, _PATCHES_TYPES)
        )

    is_frame_patches = sample_collection._is_frames
    patches_field = _get_patches_field(
        sample_collection, field, keep_label_lists
    )

    dataset = fod.Dataset(
        name=name,
        persistent=persistent,
        _patches=_generated,
        _frames=is_frame_patches and _generated,
    )
    dataset.media_type = fom.IMAGE
    dataset.add_sample_field("sample_id", fof.ObjectIdField)
    dataset.create_index("sample_id")

    if is_frame_patches:
        dataset.add_sample_field("frame_id", fof.ObjectIdField)
        dataset.add_sample_field("frame_number", fof.FrameNumberField)
        dataset.create_index("frame_id")
        dataset.create_index([("sample_id", 1), ("frame_number", 1)])

    # A field nested in an embedded document needs its parent created first
    if len(keys) == 2:
        dataset.add_sample_field(
            keys[0],
            fof.EmbeddedDocumentField,
            embedded_doc_type=foo.DynamicEmbeddedDocument,
        )

    dataset.add_sample_field(field, **foo.get_field_kwargs(patches_field))

    if other_fields:
        src_schema = sample_collection.get_field_schema()
        curr_schema = dataset.get_field_schema()

        if other_fields == True:
            other_fields = [f for f in src_schema if f not in curr_schema]

        add_fields = [f for f in other_fields if f not in curr_schema]
        add_schema = {k: v for k, v in src_schema.items() if k in add_fields}
        dataset._sample_doc_cls.merge_field_schema(add_schema)

    fod._clone_indexes_for_patches_view(
        sample_collection,
        dataset,
        patches_fields=[field],
        other_fields=other_fields,
        include_indexes=include_indexes,
    )

    _make_pretty_summary(dataset, is_frame_patches=is_frame_patches)

    patches_view = _make_patches_view(
        sample_collection,
        field,
        other_fields=other_fields,
        keep_label_lists=keep_label_lists,
    )
    _write_samples(dataset, patches_view)

    return dataset


def _get_patches_field(sample_collection, field_name, keep_label_lists):
    if keep_label_lists:
        return sample_collection.get_field(field_name)

    _, path = sample_collection._get_label_field_path(field_name)
    return sample_collection.get_field(path, leaf=True)


def make_evaluation_patches_dataset(
    sample_collection,
    eval_key,
    other_fields=None,
    include_indexes=False,
    name=None,
    persistent=False,
    _generated=False,
):
    """Creates a dataset based on the results of the evaluation with the given
    key that contains one sample for each true positive, false positive, and
    false negative example in the input collection, respectively.

    True positive examples will result in samples with both their ground truth
    and predicted fields populated, while false positive/negative examples will
    only have one of their corresponding predicted/ground truth fields
    populated, respectively.

    If multiple predictions are matched to a ground truth object (e.g., if the
    evaluation protocol includes a crowd attribute), then all matched
    predictions will be stored in the single sample along with the ground truth
    object.

    The returned dataset will also have top-level ``type`` and ``iou`` fields
    populated based on the evaluation results for that example, as well as a
    ``sample_id`` field recording the sample ID of the example, and a ``crowd``
    field if the evaluation protocol defines a crowd attribute.

    .. note::

        The returned dataset will contain patches for the contents of the input
        collection, which may differ from the view on which the ``eval_key``
        evaluation was performed. This may exclude some labels that were
        evaluated and/or include labels that were not evaluated.

        If you would like to see patches for the exact view on which an
        evaluation was performed, first call
        :meth:`load_evaluation_view() <fiftyone.core.collections.SampleCollection.load_evaluation_view>`
        to load the view and then convert to patches.

    Args:
        sample_collection: a
            :class:`fiftyone.core.collections.SampleCollection`
        eval_key: an evaluation key that corresponds to the evaluation of
            ground truth/predicted fields that are of type
            :class:`fiftyone.core.labels.Detections`,
            :class:`fiftyone.core.labels.Polylines`, or
            :class:`fiftyone.core.labels.Keypoints`
        other_fields (None): controls whether fields other than the
            ground truth/predicted fields and the default sample fields are
            included. Can be any of the following:

            -   a field or list of fields to include
            -   ``True`` to include all other fields
            -   ``None``/``False`` to include no other fields
        include_indexes (False): whether to recreate any custom indexes on the
            ground truth/predicted fields and ``other_fields`` on the new
            dataset (True) or a list of specific indexes or index prefixes to
            recreate. By default, no custom indexes are recreated
        name (None): a name for the dataset
        persistent (False): whether the dataset should persist in the database
            after the session terminates
        _generated (False): internal flag forwarded to the dataset constructor
            as ``_patches`` (and as ``_frames`` for frame patches)

    Returns:
        a :class:`fiftyone.core.dataset.Dataset`

    Raises:
        ValueError: if the evaluation is incompatible with the collection's
            frame/sample level, or its fields are not supported label types
    """
    # Parse evaluation info
    eval_info = sample_collection.get_evaluation_info(eval_key)
    pred_field = eval_info.config.pred_field
    gt_field = eval_info.config.gt_field
    crowd_attr = getattr(eval_info.config, "iscrowd", None)

    is_frame_patches = sample_collection._is_frames

    if is_frame_patches:
        if not pred_field.startswith(sample_collection._FRAMES_PREFIX):
            raise ValueError(
                "Cannot extract evaluation patches for sample-level "
                "evaluation '%s' from a frames view" % eval_key
            )

        # In a frames view, frame fields are addressed without the prefix
        pred_field = pred_field[len(sample_collection._FRAMES_PREFIX) :]
        gt_field = gt_field[len(sample_collection._FRAMES_PREFIX) :]
    elif sample_collection._is_frame_field(pred_field):
        raise ValueError(
            "Frame evaluation patches cannot be directly extracted; you must "
            "first convert your video dataset to frames via `to_frames()`"
        )

    if etau.is_str(other_fields):
        other_fields = [other_fields]
    elif not other_fields:
        # `False` is documented as equivalent to `None`; normalize it so that
        # downstream helpers never try to iterate over a boolean
        other_fields = None

    # Validate the label types *before* creating the dataset so that invalid
    # inputs don't leave behind an orphaned (possibly persistent) dataset
    for label_field in (gt_field, pred_field):
        label_type = sample_collection._get_label_field_type(label_field)
        if label_type is None or not issubclass(label_type, _PATCHES_TYPES):
            raise ValueError(
                "Invalid label field type %s. Extracting patches is only "
                "supported for the following types: %s"
                % (label_type, _PATCHES_TYPES)
            )

    _gt_field = sample_collection.get_field(gt_field)
    _pred_field = sample_collection.get_field(pred_field)

    # Setup dataset with correct schema
    dataset = fod.Dataset(
        name=name,
        persistent=persistent,
        _patches=_generated,
        _frames=is_frame_patches and _generated,
    )
    dataset.media_type = fom.IMAGE
    dataset.add_sample_field("sample_id", fof.ObjectIdField)
    dataset.create_index("sample_id")

    if is_frame_patches:
        dataset.add_sample_field("frame_id", fof.ObjectIdField)
        dataset.add_sample_field("frame_number", fof.FrameNumberField)
        dataset.create_index("frame_id")
        dataset.create_index([("sample_id", 1), ("frame_number", 1)])

    dataset.add_sample_field(gt_field, **foo.get_field_kwargs(_gt_field))
    dataset.add_sample_field(pred_field, **foo.get_field_kwargs(_pred_field))

    if crowd_attr is not None:
        dataset.add_sample_field("crowd", fof.BooleanField)

    dataset.add_sample_field("type", fof.StringField)
    dataset.add_sample_field("iou", fof.FloatField)

    if other_fields:
        src_schema = sample_collection.get_field_schema()
        curr_schema = dataset.get_field_schema()

        if other_fields == True:
            other_fields = [f for f in src_schema if f not in curr_schema]

        add_fields = [f for f in other_fields if f not in curr_schema]
        add_schema = {k: v for k, v in src_schema.items() if k in add_fields}
        dataset._sample_doc_cls.merge_field_schema(add_schema)

    fod._clone_indexes_for_patches_view(
        sample_collection,
        dataset,
        patches_fields=[gt_field, pred_field],
        other_fields=other_fields,
        include_indexes=include_indexes,
    )

    _make_pretty_summary(dataset, is_frame_patches=is_frame_patches)

    # Add ground truth patches
    gt_view = _make_eval_view(
        sample_collection,
        eval_key,
        gt_field,
        other_fields=other_fields,
        crowd_attr=crowd_attr,
    )
    _write_samples(dataset, gt_view)

    # Merge matched predictions
    _merge_matched_labels(dataset, sample_collection, eval_key, pred_field)

    # Add unmatched predictions
    unmatched_pred_view = _make_eval_view(
        sample_collection,
        eval_key,
        pred_field,
        other_fields=other_fields,
        skip_matched=True,
    )
    _add_samples(dataset, unmatched_pred_view)

    return dataset


def _make_pretty_summary(dataset, is_frame_patches=False):
    if is_frame_patches:
        set_fields = [
            "id",
            "sample_id",
            "frame_id",
            "filepath",
            "frame_number",
        ]
    else:
        set_fields = ["id", "sample_id", "filepath"]

    all_fields = dataset._sample_doc_cls._fields_ordered
    pretty_fields = set_fields + [f for f in all_fields if f not in set_fields]
    dataset._sample_doc_cls._fields_ordered = tuple(pretty_fields)


def _make_patches_view(
    sample_collection, field, other_fields=None, keep_label_lists=False
):
    """Returns a view of ``sample_collection`` that contains one document per
    label in ``field``, shaped for insertion into a patches dataset.

    Args:
        sample_collection: a
            :class:`fiftyone.core.collections.SampleCollection`
        field: the label field to unwind, whose type must be one of
            ``_PATCHES_TYPES``
        other_fields (None): an optional field or list of additional fields
            to project into each patch. ``None``/``False`` include none
        keep_label_lists (False): whether to wrap each unwound label in a
            single-element list at its label list path rather than storing it
            as a single label at ``field``

    Returns:
        the view produced by ``sample_collection.mongo()``

    Raises:
        ValueError: if ``field`` is not of a supported label type
    """
    root, _ = sample_collection._get_label_field_root(field)
    label_type = sample_collection._get_label_field_type(field)

    if not issubclass(label_type, _PATCHES_TYPES):
        raise ValueError(
            "Invalid label field type %s. Extracting patches is only "
            "supported for the following types: %s"
            % (label_type, _PATCHES_TYPES)
        )

    # Accept a single field name, and treat `False` like `None` as
    # documented by the public callers (iterating `False` would crash)
    if etau.is_str(other_fields):
        other_fields = [other_fields]

    project = {
        "_id": False,
        "_media_type": True,
        "filepath": True,
        "metadata": True,
        "tags": True,
        "created_at": True,
        "last_modified_at": True,
        field + "._cls": True,
        root: True,
    }

    if "." in field:
        project[field.split(".", 1)[0] + "._cls"] = True  # embedded fields

    if other_fields:
        for f in other_fields:
            project[f] = True
            if "." in f:
                project[f.split(".", 1)[0] + "._cls"] = True  # embedded fields

    if sample_collection._is_frames:
        # In a frames view, each document's `_id` is the frame's ID
        project["_sample_id"] = True
        project["_frame_id"] = "$_id"
        project["frame_number"] = True
    else:
        project["_sample_id"] = "$_id"

    pipeline = [
        {"$project": project},
        # One document per label in the list
        {"$unwind": "$" + root},
        # Fresh random value per patch (presumably used for random sampling)
        {"$addFields": {"_rand": {"$rand": {}}}},
        # Each patch takes the ID of the label it contains
        {"$addFields": {"_id": "$" + root + "._id"}},
    ]

    if keep_label_lists:
        pipeline.append({"$addFields": {root: ["$" + root]}})
    elif root != field:
        pipeline.append({"$addFields": {field: "$" + root}})

    return sample_collection.mongo(pipeline)


def _make_eval_view(
    sample_collection,
    eval_key,
    field,
    other_fields=None,
    skip_matched=False,
    crowd_attr=None,
):
    """Builds a patches view of ``field`` annotated with evaluation results.

    Each patch receives:

    -   ``type``: from ``<field>.<eval_key>``
    -   ``iou``: from ``<field>.<eval_key>_iou``
    -   ``crowd``: only when ``crowd_attr`` is given. It is the first non-null
        of ``<field>.<crowd_attr>`` or
        ``<field>.attributes.<crowd_attr>.value``, converted via ``$toBool``,
        and null if neither is set

    The label field is then converted back into a label list via
    :func:`_upgrade_labels`.

    Args:
        sample_collection: the collection to extract patches from
        eval_key: the evaluation key
        field: the label field to extract patches from
        other_fields (None): optional additional field paths to include
        skip_matched (False): whether to keep only labels whose
            ``<eval_key>_id`` is ``_NO_MATCH_ID`` or missing/null
        crowd_attr (None): an optional crowd attribute name

    Returns:
        the resulting view
    """
    type_path = field + "." + eval_key
    id_ref = "$" + type_path + "_id"
    iou_ref = "$" + type_path + "_iou"

    view = _make_patches_view(
        sample_collection, field, other_fields=other_fields
    )

    if skip_matched:
        # "not greater than null" means the ID is null or missing
        unmatched = {
            "$or": [
                {"$eq": [id_ref, _NO_MATCH_ID]},
                {"$not": {"$gt": [id_ref, None]}},
            ]
        }
        view = view.mongo([{"$match": {"$expr": unmatched}}])

    view = view.mongo(
        [{"$addFields": {"type": "$" + type_path, "iou": iou_ref}}]
    )

    if crowd_attr is not None:
        candidate_refs = [
            "$" + field + "." + crowd_attr,
            # @todo can remove this when `Attributes` are deprecated
            "$" + field + ".attributes." + crowd_attr + ".value",
        ]

        # Fold from the last candidate outward, so that the first non-null
        # candidate wins and the innermost fallback is null
        crowd_expr = None
        for ref in reversed(candidate_refs):
            crowd_expr = {
                "$cond": {
                    "if": {"$gt": [ref, None]},
                    "then": {"$toBool": ref},
                    "else": crowd_expr,
                }
            }

        view = view.mongo([{"$addFields": {"crowd": crowd_expr}}])

    return _upgrade_labels(view, field)


def _upgrade_labels(view, field):
    tmp_field = "_" + field
    label_type = view._get_label_field_type(field)
    return view.mongo(
        [
            {"$addFields": {tmp_field: "$" + field}},
            {"$project": {field: False}},
            {
                "$addFields": {
                    field: {
                        "_cls": label_type.__name__,
                        label_type._LABEL_LIST_FIELD: ["$" + tmp_field],
                    }
                }
            },
            {"$project": {tmp_field: False}},
        ]
    )


def _merge_matched_labels(dataset, src_collection, eval_key, field):
    """Merges labels from ``src_collection`` into the matching patches.

    The labels of ``field`` in ``src_collection`` are unwound. Labels whose
    ``<eval_key>_id`` is neither null/missing nor ``_NO_MATCH_ID`` are grouped
    by that ID, converted to an ObjectId.

    Each group is written as a label list into ``field`` of the document in
    ``dataset``'s sample collection whose ``_id`` equals the group ID. Groups
    with no such document are discarded.

    Args:
        dataset: the dataset whose sample collection receives the labels
        src_collection: the collection providing the labels
        eval_key: the evaluation key
        field: the label list field to merge
    """
    label_cls = src_collection._get_label_field_type(field)
    list_key = label_cls._LABEL_LIST_FIELD

    labels_path = field + "." + list_key
    match_id_path = labels_path + "." + eval_key + "_id"

    is_matched = {
        "$and": [
            {"$gt": ["$" + match_id_path, None]},
            {"$ne": ["$" + match_id_path, _NO_MATCH_ID]},
        ]
    }

    grouped_labels = {
        "_id": {"$toObjectId": "$" + match_id_path},
        "_labels": {"$push": "$" + labels_path},
    }

    rebuilt_field = {
        field: {
            "_cls": label_cls.__name__,
            list_key: "$_labels",
        }
    }

    merge_spec = {
        "into": dataset._sample_collection_name,
        "on": "_id",
        "whenMatched": "merge",
        "whenNotMatched": "discard",
    }

    src_collection._aggregate(
        post_pipeline=[
            {"$project": {labels_path: True}},
            {"$unwind": "$" + labels_path},
            {"$match": {"$expr": is_matched}},
            {"$group": grouped_labels},
            {"$project": rebuilt_field},
            {"$merge": merge_spec},
        ]
    )


def _write_samples(dataset, src_collection):
    src_collection._aggregate(
        detach_frames=True,
        detach_groups=True,
        post_pipeline=[
            {"$addFields": {"_dataset_id": dataset._doc.id}},
            {"$out": dataset._sample_collection_name},
        ],
    )


def _add_samples(dataset, src_collection):
    src_collection._aggregate(
        detach_frames=True,
        detach_groups=True,
        post_pipeline=[
            {"$addFields": {"_dataset_id": dataset._doc.id}},
            {
                "$merge": {
                    "into": dataset._sample_collection_name,
                    "on": "_id",
                    "whenMatched": "keepExisting",
                    "whenNotMatched": "insert",
                }
            },
        ],
    )
